import math

import torch
from torch import nn
from torch.nn import functional as F
import torchvision
from transformers import RobertaTokenizer, RobertaModel

from models.bottlenecks import DownSample, UpSample

import ipdb
# Debugging shortcut: calling st() invokes ipdb.set_trace().
st = ipdb.set_trace


class AnalogicalTransformer(nn.Module):
    """Analogical correspondence network.

    A stack of transformer decoder layers refines a set of queries for a
    second scene against a memory built from a first (annotated) scene and
    from the second scene itself. After every layer the queries are turned
    into per-location score maps and box predictions.

    Args:
        in_dim: input channel count, forwarded to ``DownSample``.
        out_dim: feature width used by every sub-module.
        num_layers: number of decoder layers (one box head per layer).
        num_query: number of learnable queries. ``None`` disables them; the
            second scene's queries are then a detached copy of the features
            pooled from the first scene.
        hybrid: if True, the second scene's queries are the learnable
            queries concatenated with the pooled ones. Requires
            ``num_query`` to be set.
        use_x1: stored on the module; not read by any code in this class.

    Raises:
        ValueError: if ``hybrid`` is True while ``num_query`` is None.
    """

    def __init__(self, in_dim=6, out_dim=256, num_layers=3, num_query=None,
                 hybrid=False, use_x1=True):
        super().__init__()
        # Fail early: the hybrid branch of _featurize reads
        # self.queries.weight, which does not exist without num_query.
        if hybrid and num_query is None:
            raise ValueError("hybrid=True requires num_query to be set")
        self.hybrid = hybrid
        self.num_query = num_query
        self.use_x1 = use_x1

        # Bottlenecks
        self.down = DownSample(in_dim, out_dim)
        self.up = UpSample(out_dim, out_dim)

        # Language encoder (frozen; see train() for why eval mode is pinned)
        self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
        self.pre_encoder = RobertaModel.from_pretrained('roberta-base').eval()
        for param in self.pre_encoder.parameters():
            param.requires_grad = False
        # Projects the encoder's last hidden state (768 wide) to out_dim
        self.text_proj = nn.Linear(768, out_dim)

        # Special tokens (X1, L1, Y1, X2, L2, Y2)
        # NOTE(review): _decode only adds indices 0-4; index 5 is unused.
        self.sp_tokens = nn.Embedding(6, out_dim)

        # Positional embeddings
        self.pos_emb = PositionEmbeddingSine(out_dim // 2, normalize=True)

        # "Object" queries
        if num_query is not None:
            self.queries = nn.Embedding(num_query, out_dim)
        else:
            self.queries = None
        # One box-regression head per decoder layer (4 outputs per query)
        self.obj_heads = nn.ModuleList([nn.Sequential(
            nn.Conv1d(out_dim, out_dim, 1, bias=False),
            nn.BatchNorm1d(out_dim), nn.ReLU(), nn.Dropout(0.3),
            nn.Conv1d(out_dim, out_dim, 1, bias=False),
            nn.BatchNorm1d(out_dim), nn.ReLU(), nn.Dropout(0.3),
            nn.Conv1d(out_dim, 4, 1)
        ) for _ in range(num_layers)])

        # Transformer layers
        self.decoder = nn.ModuleList([nn.TransformerDecoderLayer(
            out_dim, nhead=8, dim_feedforward=2 * out_dim
        ) for _ in range(num_layers)])

    def train(self, mode=True):
        """Set the training mode, keeping the frozen encoder in eval mode.

        ``nn.Module.train`` recurses into every child, which would switch
        the frozen language encoder back to train mode and re-enable its
        dropout. It is put back into eval mode after the recursive call.
        """
        super().train(mode)
        self.pre_encoder.eval()
        return self

    def forward(self, im_1, lang_1, box_1, act_1, mask_1, im_2, lang_2):
        """Forward pass, im (B, 6, H, W), box (B, 8, 4), act (B, 2, 2).

        Returns:
            (scores, boxes): two lists with one entry per decoder layer,
            as produced by ``_decode``.
        """
        # Encode scenes
        lat_1 = self.down(im_1)
        lat_2 = self.down(im_2)
        feats_1 = self.up(lat_1)
        feats_2 = self.up(lat_2)

        # Encode language
        lang_1, pad_1 = self._encode_lang(lang_1)
        lang_2, pad_2 = self._encode_lang(lang_2)

        # Add positional embeddings, then flatten space: (B, H'*W', F)
        lat_1 = lat_1 + self.pos_emb(lat_1)
        lat_2 = lat_2 + self.pos_emb(lat_2)
        lat_1 = lat_1.flatten(2).transpose(1, 2)
        lat_2 = lat_2.flatten(2).transpose(1, 2)

        # Featurize clusters (B, n_clusters, F)
        queries_1, queries_2 = self._featurize(feats_1, box_1, act_1)

        # Padding
        pad = self._construct_pad(pad_1, pad_2, lat_1, mask_1)

        # Decode
        return self._decode(
            lat_1, lang_1, queries_1, lat_2, lang_2, queries_2, feats_2, pad
        )

    def _construct_pad(self, pad_1, pad_2, lat_1, mask_1):
        """Build the memory key padding mask (True = ignore this key).

        Segments follow the key order used in ``_decode``: scene-1 latents,
        language 1, scene-1 boxes, the two action entries, scene-2 latents,
        language 2. ``pad_1``/``pad_2``/``mask_1`` are "valid" masks, hence
        the inversion. The scene-2 segment reuses ``lat_1``'s length, so
        both scenes must yield the same number of latent tokens.
        """
        batch, n_lat = lat_1.size(0), lat_1.size(1)
        device = lat_1.device
        # Allocate directly as bool on the target device
        lat_keep = torch.zeros(batch, n_lat, dtype=torch.bool, device=device)
        act_keep = torch.zeros(batch, 2, dtype=torch.bool, device=device)
        mem_pad = torch.cat((
            lat_keep,
            ~pad_1.bool(),
            ~mask_1.bool(),
            act_keep,
            lat_keep,
            ~pad_2.bool()
        ), 1)
        return mem_pad

    def _encode_lang(self, lang):
        """Tokenize and encode text; return (projected tokens, attn mask)."""
        device = next(self.pre_encoder.parameters()).device
        inputs = self.tokenizer(lang, return_tensors="pt", padding=True)
        lang = {key: val.to(device) for key, val in inputs.items()}
        # The encoder is frozen, so no graph is needed for it; only the
        # projection below is trained.
        with torch.no_grad():
            pad = lang['attention_mask']
            lang = self.pre_encoder(**lang)
        return self.text_proj(lang.last_hidden_state), pad

    def _pool_feats(self, scene, objs):
        """RoI-align every box in ``objs`` to a 1x1 feature from ``scene``."""
        # A list of per-sample box tensors lets roi_align assign batch ids
        return torchvision.ops.roi_align(scene, list(objs), 1)

    def _featurize(self, feats_1, box_1, act_1):
        """Pool scene-1 features into queries and initialize scene-2 queries.

        Returns:
            (queries_1, queries_2). ``queries_1`` has one feature vector per
            entry of ``box_1`` plus one per action entry, i.e.
            (B, box_1.size(1) + act_1.size(1), F).
        """
        # Turn each action entry into a box (value - 1, value + 1)
        # assumes act_1 holds point coordinates — TODO confirm
        act_1 = torch.cat((act_1 - 1, act_1 + 1), -1)
        boxes = torch.cat((box_1, act_1), 1)
        queries_1 = self._pool_feats(feats_1, boxes)
        queries_1 = queries_1.reshape(len(boxes), boxes.size(1), -1)

        if self.hybrid:
            queries_2 = torch.cat((
                self.queries.weight[None].repeat(len(feats_1), 1, 1),
                queries_1
            ), 1)
        elif self.queries is not None:
            # for point cloud 2, initialize learnable queries
            queries_2 = self.queries.weight[None].repeat(len(feats_1), 1, 1)
        else:
            # for point cloud 2, initialize with queries_1
            queries_2 = queries_1.detach().clone()
            # gradients are not back-propagated to queries_2 in this case

        return queries_1, queries_2

    def _decode(self, latents_1, lang_1, queries_1,
                latents_2, lang_2, queries_2, feats_2, pad):
        """Run the decoder stack and read out scores/boxes after each layer.

        Returns:
            scores: list of (B, 2, H, W) maps, 10x the cosine similarity
                between the last two queries and every location of
                ``feats_2``.
            boxes: list of (B, n_queries - 2, 4) outputs of the per-layer
                box head applied to all but the last two queries.
        """
        # Construct keys: each segment is tagged with its special token
        _embs = self.sp_tokens.weight
        keys = torch.cat((
            latents_1 + _embs[0][None, None],
            lang_1 + _embs[1][None, None],
            queries_1 + _embs[2][None, None],
            latents_2 + _embs[3][None, None],
            lang_2 + _embs[4][None, None],
        ), 1).transpose(0, 1)

        # Loop-invariant: unit-normalized scene-2 features, (B, F, H*W)
        batch, height, width = len(feats_2), feats_2.size(2), feats_2.size(3)
        feats_2_unit = F.normalize(feats_2.flatten(2), dim=1)

        scores, boxes = [], []
        for k, layer in enumerate(self.decoder):
            # feed to decoder layer (sequence-first layout, hence transposes)
            queries_2 = layer(
                queries_2.transpose(0, 1),
                keys,
                memory_key_padding_mask=pad
            ).transpose(0, 1)

            # score every point (B, 2, H, W)
            similarity = torch.matmul(
                F.normalize(queries_2[:, -2:], dim=-1),
                feats_2_unit
            )
            scores.append(10 * similarity.reshape(batch, -1, height, width))

            # compute boxes (B, 8, 4)
            boxes.append(self.obj_heads[k](
                queries_2[:, :-2].transpose(1, 2)
            ).transpose(1, 2))
        return scores, boxes


class PositionEmbeddingSine(nn.Module):
    """
    This is a more standard version of the position embedding,
    very similar to the one
    used by the Attention is all you need paper,
    generalized to work on images.

    Args:
        num_pos_feats: number of channels per axis; the output has
            ``2 * num_pos_feats`` channels (y half first, then x half).
            Must be even because sin/cos channels are interleaved in pairs.
        temperature: base of the geometric frequency progression.
        normalize: if True, positions are rescaled to (0, scale) before
            the sin/cos are taken.
        scale: range used when ``normalize`` is True; defaults to 2*pi.

    Raises:
        ValueError: if ``num_pos_feats`` is odd, or if ``scale`` is given
            while ``normalize`` is False (it would be silently ignored).
    """

    def __init__(self, num_pos_feats=64, temperature=10000, normalize=False,
                 scale=None):
        super().__init__()
        if num_pos_feats % 2 != 0:
            # An odd count makes the sin and cos slices differ in length,
            # which would only surface as a torch.stack error in forward.
            raise ValueError("num_pos_feats must be even")
        if scale is not None and not normalize:
            raise ValueError("normalize should be True if scale is passed")
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        self.scale = 2 * math.pi if scale is None else scale

    @staticmethod
    def _interleave_sin_cos(pos):
        """Apply sin to even channels, cos to odd ones, keeping the order."""
        return torch.stack((
            pos[:, :, :, 0::2].sin(),
            pos[:, :, :, 1::2].cos()
        ), dim=4).flatten(3)

    def forward(self, x):
        """Image x (B, F, H, W); returns float32 (B, 2*num_pos_feats, H, W).

        Only the shape and device of ``x`` are used, not its values.
        """
        # 1-based row / column indices via cumulative sums of ones
        not_mask = torch.ones_like(x[:, 0])
        y_embed = not_mask.cumsum(1, dtype=torch.float32)
        x_embed = not_mask.cumsum(2, dtype=torch.float32)
        if self.normalize:
            eps = 1e-6
            # Shift by 0.5, then divide by the last index (= H or W)
            y_embed = (y_embed - 0.5) / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = (x_embed - 0.5) / (x_embed[:, :, -1:] + eps) * self.scale

        # Each consecutive (sin, cos) channel pair shares one frequency
        dim_t = torch.arange(self.num_pos_feats, device=x.device).float()
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = self._interleave_sin_cos(x_embed[:, :, :, None] / dim_t)
        pos_y = self._interleave_sin_cos(y_embed[:, :, :, None] / dim_t)
        pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
        return pos
